import glob
import os
import re
import json
from tqdm import tqdm
from typing import Tuple, Dict, List, Union


# Prompt text explaining the format of a cognitive map that is GIVEN to the
# model, for maps that carry separate "objects" and "views" sections.
# The text holds a `{birdview}` placeholder.
# NOTE(review): the `_OLD` suffix suggests this is superseded by
# COG_MAP_DESCRIPTION_FOR_INPUT; the placeholder is presumably filled in by a
# caller outside this view -- confirm both.
COG_MAP_DESCRIPTION_FOR_INPUT_SEP_OBJ_VIEW_OLD = '''[Cognitive Map Format]
We provide you a 2D grid map of the scene that is related to the question you should answer. Below is the description of the map:
- The map uses a 10x10 grid where [0,0] is at the top-left corner and [9,9] is at the bottom-right corner {birdview}
- Directions are defined as:
  * up = towards the top of the grid (decreasing y-value)
  * right = towards the right of the grid (increasing x-value)
  * down = towards the bottom of the grid (increasing y-value)
  * left = towards the left of the grid (decreasing x-value)
  * inner = straight into the 2D map (perpendicular to the grid, pointing away from you)
  * outer = straight out of the 2D map (perpendicular to the grid, pointing towards you)
- "objects" lists all important items in the scene with their positions
- "facing" indicates which direction an object is oriented towards (when applicable)
- "views" represents the different camera viewpoints in the scene
'''

# Prompt text explaining the format of a cognitive map that is GIVEN to the
# model, for maps where camera views are listed alongside the objects and are
# recognised by their name ("Image x" / "View x").
# The text holds a `{birdview}` placeholder.
# NOTE(review): the `_OLD` suffix suggests this variant is no longer the
# primary one; the placeholder is presumably filled in by a caller outside
# this view -- confirm both.
COG_MAP_DESCRIPTION_FOR_INPUT_MERGED_VIEWS_OLD = '''[Cognitive Map Format]
We provide you a 2D grid map of the scene that is related to the question you should answer. Below is the description of the map:
- The map uses a 10x10 grid where [0,0] is at the top-left corner and [9,9] is at the bottom-right corner {birdview}
- Directions are defined as:
  * up = towards the top of the grid (decreasing y-value)
  * right = towards the right of the grid (increasing x-value)
  * down = towards the bottom of the grid (increasing y-value)
  * left = towards the left of the grid (decreasing x-value)
  * inner = straight into the 2D map (perpendicular to the grid, pointing away from you)
  * outer = straight out of the 2D map (perpendicular to the grid, pointing towards you)
- "facing" indicates which direction an object is oriented towards (when applicable)
- if name is "Image x" or "View x", it means the camera position and direction of the x-th image/view
'''

# Prompt text explaining the format of a cognitive map that is GIVEN to the
# model. Compared with the *_SEP_OBJ_VIEW_OLD text it adds one bullet
# describing a "facing_objects" field.
# The text holds a `{birdview}` placeholder.
# NOTE(review): the placeholder is presumably filled in by a caller outside
# this view -- confirm.
COG_MAP_DESCRIPTION_FOR_INPUT = '''[Cognitive Map Format]
We provide you a 2D grid map of the scene that is related to the question you should answer. Below is the description of the map:
- The map uses a 10x10 grid where [0,0] is at the top-left corner and [9,9] is at the bottom-right corner {birdview}
- Directions are defined as:
  * up = towards the top of the grid (decreasing y-value)
  * right = towards the right of the grid (increasing x-value)
  * down = towards the bottom of the grid (increasing y-value)
  * left = towards the left of the grid (decreasing x-value)
  * inner = straight into the 2D map (perpendicular to the grid, pointing away from you)
  * outer = straight out of the 2D map (perpendicular to the grid, pointing towards you)
- "objects" lists all important items in the scene with their positions
- "facing" indicates which direction an object is oriented towards (when applicable)
- "views" represents the different camera viewpoints in the scene
- "facing_objects" indicates the camera is facing which objects
'''

# Prompt text instructing the model to PRODUCE a cognitive map in the short
# format: a flat JSON object keyed by object category, with no "views" section
# and no inner/outer directions.
# Placeholder-like tokens in the text: `{{obj}}`, `{birdview}` and
# `<orientation_info>`.
# NOTE(review): the text also contains literal JSON braces, so calling
# str.format() on the whole string would fail; the tokens are presumably
# substituted some other way (e.g. str.replace) by a caller outside this
# view -- confirm.
COG_MAP_DESCRIPTION_FOR_OUTPUT_SHORTEN = '''[Task]
Your task is to analyze the spatial arrangement of objects in the scene by examining the provided images, which show the scene from different viewpoints. You will then create a detailed cognitive map representing the scene using a **10x10 grid coordinate system**. 

[Rules]
1. Focus ONLY on these categories of objects in the scene: {{obj}}
2. Create a cognitive map with the following structure{birdview}:
   - A 10x10 grid where [0, 0] is at the top-left corner and [9, 9] is at the bottom-right corner
   - up = towards the top of the grid (decreasing y)
   - right = towards the right of the grid (increasing x)
   - down = towards the bottom of the grid (increasing y)
   - left = towards the left of the grid (decreasing x)
   - Include positions of all objects from the specified categories
   - Estimate the center location (coordinates [x, y]) of each instance within provided categories
   - If a category contains multiple instances, include all of them
   - Object positions must maintain accurate relative spatial relationships
   - Combine and merge information from the images since they are pointing to the same scene, calibrating the object locations with grid coordinates accordingly
3. Carefully integrate information from all views to create a single coherent spatial representation.
<orientation_info>

[Output]
1. Given the provided views and main objects mentioned in the above rules, you **MUST** present your cognitive map in the following JSON format **before your reasoning**:
{
    "object_category_1": {"position": [x, y]},
    "object_category_2": {"position": [x, y], "facing": "direction"}, # if the object is asked for orientation
    ...
}

2. Next, please also provide your reasons step by step in details, then provide *ONE* correct answer selecting from the options. Your response's format should be like "<CogMap>\n <Your cognitive map>\n<Reasoning>\n ... \n<Answer>\n Therefore, my answer is <selected option>". Your <selected option> must be in the format like "A. Above". Your option must be from the available options.'''

# Prompt text instructing the model to PRODUCE a cognitive map in the full
# format: a JSON object with an "objects" list and a "views" list (camera
# position and facing per view), including the inner/outer directions.
# Placeholder-like tokens in the text: `{{obj}}`, `{birdview}` and
# `<orientation_info>`.
# NOTE(review): the text also contains literal JSON braces, so calling
# str.format() on the whole string would fail; the tokens are presumably
# substituted some other way (e.g. str.replace) by a caller outside this
# view -- confirm.
COG_MAP_DESCRIPTION_FOR_OUTPUT = '''[Task]
Your task is to analyze the spatial arrangement of objects in the scene by examining the provided images, which show the scene from different viewpoints. You will then create a detailed cognitive map representing the scene using a 10x10 grid coordinate system. 

[Rules]
1. Focus ONLY on these categories of objects in the scene: {{obj}}
2. Create a cognitive map with the following structure{birdview}:
   - A 10x10 grid where [0,0] is at the top-left corner and [9,9] is at the bottom-right corner
   - up = towards the top of the grid (decreasing y)
   - right = towards the right of the grid (increasing x)
   - down = towards the bottom of the grid (increasing y)
   - left = towards the left of the grid (decreasing x)
   - inner = straight into the 2D map (perpendicular to the grid, pointing away from you)
   - outer = straight out of the 2D map (perpendicular to the grid, pointing towards you)
   - Include positions of all objects from the specified categories
   - Estimate the center location (coordinates [x, y]) of each instance within provided categories
   - If a category contains multiple instances, include all of them
   - Each object's estimated location should accurately reflect its real position in the scene, preserving the relative spatial relationships among all objects
   - Combine and merge information from the images since they are pointing to the same scene, calibrating the object locations accordingly
   - Include camera positions and directions for each view
3. Carefully integrate information from all views to create a single coherent spatial representation.
<orientation_info>

[Output]
1. Given the provided views and main objects mentioned in the above rules, you **MUST** present your cognitive map in the following JSON format **before your reasoning**:
{
  "objects": [
    {"name": "object_name", "position": [x, y], "facing": "direction"},
    {"name": "object_without_orientation", "position": [x, y]}
  ],
  "views": [
    {"name": "View/Image 1", "position": [x, y], "facing": "direction"},
    {"name": "View/Image 2", "position": [x, y], "facing": "direction"}
  ]
}

2. Next, please also provide your reasons step by step in details, then provide *ONE* correct answer selecting from the options. Your response's format should be like "<CogMap>\n <Your cognitive map>\n<Reasoning>\n ... \n<Answer>\n Therefore, my answer is <selected option>". Your <selected option> must be in the format like "A. Above". Your option must be from the available options.'''

## ------------------Utility Functions------------------

def format_cogmap_json(cogmap):
    """
    Render a cogmap dictionary as JSON text with one entry per line.

    The "objects" section is written first, then the "views" section; every
    entry of a section is serialised on its own line, indented by four spaces.

    Args:
        cogmap: Dictionary with 'objects' and 'views' keys, each holding a
            list of JSON-serialisable entries.

    Returns:
        str: Formatted JSON string

    Raises:
        ValueError: If parsing the rendered text does not give back lists
            equal to ``cogmap["objects"]`` and ``cogmap["views"]``.
        json.JSONDecodeError: If the rendered text is not valid JSON.
    """
    def _render_section(key, closing_line):
        # One section: header line, one line per entry (comma after all but
        # the last), then the closing bracket line.
        entries = cogmap[key]
        final_index = len(entries) - 1
        pieces = [f'  "{key}": [\n']
        for index, entry in enumerate(entries):
            separator = ',' if index < final_index else ''
            pieces.append('    ' + json.dumps(entry, ensure_ascii=False) + separator + '\n')
        pieces.append(closing_line)
        return ''.join(pieces)

    def _pretty(data):
        # Indented dump used only for the diagnostic messages below.
        return json.dumps(data, indent=2, ensure_ascii=False, separators=(',', ': '))

    result = ''.join([
        "{\n",
        _render_section("objects", '  ],\n'),
        _render_section("views", '  ]\n'),
        '}',
    ])

    # Round-trip check: the hand-built text must parse back to the input.
    try:
        parsed_json = json.loads(result)
    except json.JSONDecodeError as e:
        print(f"Warning: Generated JSON is not valid: {e}")
        print(f"Original JSON: {_pretty(cogmap)}")
        raise e

    objects_match = parsed_json["objects"] == cogmap["objects"]
    if not (objects_match and parsed_json["views"] == cogmap["views"]):
        print("Warning: Formatted JSON is not equivalent to the original dictionary")
        print(f"Original: {_pretty(cogmap)}")
        print(f"Formatted and parsed: {_pretty(parsed_json)}")
        raise ValueError("Formatted JSON is not equivalent to the original dictionary")

    return result

def extract_around_image_names(images: List[str]) -> List[int]:
    """
    Read the frame number out of every image path.

    The base name of each path is searched for "<digits>_frame", optionally
    followed by "_<suffix>", and ending in ".png", ".jpg" or ".jpeg".

    Args:
        images: Image paths (or bare file names).

    Returns:
        The integer frame numbers, in the same order as ``images``.

    Raises:
        ValueError: If a name does not match the expected pattern.
    """
    frame_pattern = r'(\d+)_frame(?:_[^.]+)?\.(?:png|jpg|jpeg)'
    frame_numbers = []
    for path in images:
        try:
            base_name = os.path.basename(path)
        except Exception:
            # Fall back to the raw value when it cannot be treated as a path.
            base_name = path
        found = re.search(frame_pattern, base_name)
        if not found:
            raise ValueError(f"Error extracting image name from {path}")
        number = int(found.group(1))
        # Special case for dl3dv10k: "33" is a known typo in the dataset for
        # frame 3; it is corrected here instead of regenerating the data.
        frame_numbers.append(3 if number == 33 else number)
    return frame_numbers

def gen_object_coordinate_dict(name, position, facing=None):
    """
    Build one cogmap object entry.

    Returns None for an empty name; otherwise a dict with "name" and
    "position", plus "facing" when a facing other than None is given.
    """
    if name == "":
        return None
    entry = {"name": name, "position": position}
    if facing is not None:
        entry["facing"] = facing
    return entry

## ------------------Generate Setting-Specific Cogmap Functions------------------

# def get_object_facing(object_name, ...) -> Union[str, None]:
#     ...

def generate_cogmap_around(item) -> Tuple[str, list, list]:
    """
    **Now this is the augmented version of the around setting, including the object orientation.**
    Generate the cogmap for around setting.

    The objects are placed on one horizontal row (y = 5) and every question
    image becomes a camera placed next to that row.

    Args:
        item: Dataset record. Keys read: "id", "meta_info", "question" and
            "images". ``meta_info[0][0]`` is used as the number of views in
            the whole image group and ``meta_info[1]`` as
            ``[object_count, object_names, object_orientations]``.

    Returns:
        A 3-tuple:
         - the cogmap as a formatted JSON string,
         - the object-name list exactly as found in ``meta_info``,
         - the names of the objects whose cogmap entry has a "facing" key
           (i.e. the objects that DO have an orientation).

    Raises:
        AssertionError: If the record fails one of the consistency checks.
        ValueError: If the image-group size / datasource combination or the
            object count is not supported, or an image name cannot be parsed.
        KeyError: If an orientation label is not in the orientation mapping.
    """
    item_id = item.get("id", "")
    meta_info = item.get("meta_info", [])
    question = item.get("question", "")
    question_images = item.get("images", [])

    # The first "_"-separated token of the id, with "around" stripped, selects
    # the datasource: "new" means "self", anything else means "dl3dv10k".
    id_prefix = item_id.split("_")[0].replace("around", "")
    datasource = "self" if id_prefix == "new" else "dl3dv10k"

    image_group_num = meta_info[0][0]  # how many views in the entire image group
    object_len = meta_info[1][0]
    objects = meta_info[1][1]  # the objects in the scene
    objects_orientation = meta_info[1][2]  # orientation info of the objects

    # Dataset orientation label -> grid direction. The front camera sits below
    # the object row looking up, so an object labelled "face" points "down".
    orientation_mapping = {
        "face": "down",
        "back": "up",
        "right": "right",
        "left": "left",
        "None": None,
        "null": None,
        None: None,
    }

    assert object_len == len(objects), f"Object number {object_len} is not equal to objects length {len(objects)}, id: {item_id}"
    assert object_len == len(objects_orientation), f"Object number {object_len} is not equal to objects orientation length {len(objects_orientation)}, id: {item_id}"
    question_images_len = len(question_images)  # how many images in the question
    assert isinstance(image_group_num, int), f"Image group number {image_group_num} is not an integer, id: {item_id}"
    assert question_images_len <= image_group_num, f"Question images length {question_images_len} is greater than image group number {image_group_num}, id: {item_id}"
    assert 3 <= image_group_num <= 6, f"Image group number {image_group_num} is not between 3 and 6"

    # Global image id -> side of the scene the image was taken from.
    if image_group_num == 3 and datasource == "self":
        global_views = {1: "front", 2: "left", 3: "right"}
    elif image_group_num == 3 and datasource == "dl3dv10k":
        global_views = {1: "front", 2: "left", 3: "right", 4: "back"}
    elif image_group_num == 4:
        global_views = {1: "front", 2: "left", 3: "right", 4: "back"}
    elif image_group_num == 5:
        global_views = {1: "front", 2: "left", 3: "right", 4: "left", 5: "right"}
    elif image_group_num == 6 and datasource == "self":
        global_views = {1: "front", 2: "left", 3: "right", 4: "left", 5: "right", 6: "back"}
    elif image_group_num == 6 and datasource == "dl3dv10k":
        global_views = {1: "front", 2: "left", 3: "right", 4: "front", 5: "left", 6: "right"}
    else:
        raise ValueError(f"Unknown image group number: {image_group_num}, datasource: {datasource}")

    question_image_ids = extract_around_image_names(question_images)

    # Ensure no repeated numbers in question_image_ids
    assert len(question_image_ids) == len(set(question_image_ids)), f"Duplicate image IDs found in question_image_ids: {question_image_ids}"
    # Validate that each image ID is within the valid range for this group
    for img_id in question_image_ids:
        assert 1 <= img_id <= len(global_views), f"Image ID {img_id} is outside the valid range [1, {len(global_views)}]. question id: {item_id}"

    view_base_name = "Image" if "image" in question.lower() else "View"

    # Camera position per side; the three entries of each list are for scenes
    # with 2, 3 and 4 objects respectively.
    mapping_view_to_coordinates = {
        "front": [[5, 6], [5, 6], [5, 6]],  # facing up
        "left": [[3, 5], [3, 5], [2, 5]],  # facing right
        "right": [[6, 5], [7, 5], [7, 5]],  # facing left
        "back": [[5, 4], [5, 4], [5, 4]],  # facing down
    }

    facing_mapping = {
        "front": "up",
        "left": "right",
        "right": "left",
        "back": "down",
    }

    # Object positions along the row, keyed by the number of objects.
    object_row_positions = {
        2: [[4, 5], [5, 5]],
        3: [[4, 5], [5, 5], [6, 5]],
        4: [[3, 5], [4, 5], [5, 5], [6, 5]],
    }
    if object_len not in object_row_positions:
        # Without this check an unsupported count would surface later as an
        # unbound-variable or index error that hides the real cause.
        raise ValueError(
            f"Unsupported object number {object_len} for around setting "
            f"(expected 2, 3 or 4), id: {item_id}"
        )

    object_coordinates = [
        gen_object_coordinate_dict(name, position, orientation_mapping[orientation])
        for name, position, orientation in zip(
            objects, object_row_positions[object_len], objects_orientation
        )
    ]

    mapping_index = object_len - 2  # eg. object_len = 2, mapping_index = 0
    view_coordinates = []
    # The k-th question image is named "<base> k" and mapped to its global id.
    for local_index, global_id in enumerate(question_image_ids, start=1):
        global_view = global_views[global_id]  # eg. global_view = "front"
        view_coordinates.append({
            "name": f"{view_base_name} {local_index}",  # eg. "View 1"
            "position": mapping_view_to_coordinates[global_view][mapping_index],  # eg. [5, 6]
            "facing": facing_mapping[global_view],  # eg. facing_mapping["front"] = "up"
        })

    # Objects with an empty name come back as None; drop them.
    object_coordinates = [obj for obj in object_coordinates if obj is not None]
    cogmap = {
        "objects": object_coordinates,
        "views": view_coordinates,
    }

    # Names of the objects that have an orientation ("facing" key present).
    oriented_objects = [obj['name'] for obj in object_coordinates if 'facing' in obj]
    return format_cogmap_json(cogmap), objects, oriented_objects

def generate_cogmap_among(item) -> Tuple[str, list, list]:
    """
    **Now this is the augmented version of the among setting, including the object orientation.**
    Generate the cogmap for among setting.

    Five objects are placed as a cross (one in the centre, one on each side)
    and every image becomes a camera next to the centre object.

    Returns:
        A 3-tuple: the cogmap as a formatted JSON string, the object-name
        list from ``meta_info[0]``, and the names of the objects whose cogmap
        entry has a "facing" key.
    """
    item_id = item.get("id", "")
    meta_info = item.get("meta_info", [])
    objects = meta_info[0]  # list of objects
    objects_orientation = meta_info[1]  # list of objects orientation
    question = item.get("question", "")
    images = item.get("images", [])

    assert len(objects) == 5, f"Among setting should have 5 objects, but got {len(objects)}, id: {item_id}"
    assert len(objects_orientation) == 5, f"Among setting should have 5 objects orientation, but got {len(objects_orientation)}"
    assert len(images) in (2, 4), f"Among setting should have 2 or 4 images, but got {len(images)}"

    # Dataset orientation label -> grid direction.
    orientation_mapping = {
        "face": "down",
        "back": "up",
        "right": "right",
        "left": "left",
        "None": None,
        "null": None,
        None: None,
    }

    # The part of each file name before the first "_" names the side.
    image_names = [os.path.basename(path).split("_")[0] for path in images]

    view_base_name = "Image" if "image" in question.lower() else "View"

    for image_name in image_names:
        assert image_name in ["front", "left", "right", "back"], f"Unknown image name: {image_name}"

    # Side -> (camera position, camera facing).
    camera_layout = {
        "front": ([5, 6], "up"),
        "left": ([4, 5], "right"),
        "right": ([6, 5], "left"),
        "back": ([5, 4], "down"),
    }

    # Slots for objects[0] .. objects[4].
    object_slots = [[5, 5], [5, 8], [2, 5], [5, 2], [8, 5]]
    placed_objects = [
        gen_object_coordinate_dict(name, slot, orientation_mapping[orientation])
        for name, slot, orientation in zip(objects, object_slots, objects_orientation)
    ]

    view_coordinates = []
    for order, side in enumerate(image_names, start=1):
        camera_position, camera_facing = camera_layout[side]
        view_coordinates.append({
            "name": f"{view_base_name} {order}",  # eg. "View 1"
            "position": camera_position,  # eg. [5, 6] for "front"
            "facing": camera_facing,  # eg. "up" for "front"
        })

    # Objects with an empty name come back as None; drop them.
    placed_objects = [entry for entry in placed_objects if entry is not None]

    cogmap = {
        "objects": placed_objects,
        "views": view_coordinates,
    }
    oriented_objects = [entry['name'] for entry in placed_objects if 'facing' in entry]
    return format_cogmap_json(cogmap), objects, oriented_objects

def generate_cogmap_translation(item) -> Tuple[str, list, list]:
    """
    Generate the cogmap for translation setting.

    ``meta_info[0]`` is a comma-separated pair of spatial relations (a single
    relation is used for both) and ``meta_info[1:]`` are the object names.
    Each supported relation pair has a fixed layout of three objects and two
    cameras; when "inverse" occurs in the item's type the two cameras swap.

    Returns:
        A 3-tuple: the cogmap as a formatted JSON string, the object names
        (``meta_info[1:]``), and an empty list.

    Raises:
        ValueError: If the relation pair has no layout.
    """
    item_type = item.get("type", "")
    meta_info = item.get("meta_info", [])

    relation_parts = meta_info[0].split(",")
    relation_1 = relation_parts[0]
    relation_2 = relation_parts[1] if len(relation_parts) > 1 else relation_1

    # center is (5, 5), horizontal is x, vertical is y
    objects = meta_info[1:]

    # (relation_1, relation_2) -> (positions of objects 0..2, camera A, camera B)
    layouts = {
        ('down', 'down'): (
            [[5, 7], [5, 5], [5, 3]],
            {"position": [5, 6], "facing": "inner"},
            {"position": [5, 4], "facing": "inner"},
        ),
        ('right', 'right'): (
            [[7, 5], [5, 5], [3, 5]],
            {"position": [6, 6], "facing": "up"},
            {"position": [4, 6], "facing": "up"},
        ),
        ('left', 'left'): (
            [[3, 5], [5, 5], [7, 5]],
            {"position": [4, 6], "facing": "up"},
            {"position": [6, 6], "facing": "up"},
        ),
        ('front', 'down'): (
            [[7, 5], [5, 5], [5, 3]],
            {"position": [8, 5], "facing": "left"},
            {"position": [5, 4], "facing": "inner"},
        ),
        ('right', 'down'): (
            [[7, 5], [5, 5], [5, 3]],
            {"position": [6, 6], "facing": "up"},
            {"position": [5, 4], "facing": "inner"},
        ),
        ('front', 'front'): (
            [[5, 7], [5, 5], [5, 3]],
            {"position": [5, 8], "facing": "up"},
            {"position": [5, 6], "facing": "up"},
        ),
        ('on', 'behind'): (
            [[5, 3], [5, 5], [3, 5]],
            {"position": [5, 4], "facing": "inner"},
            {"position": [6, 5], "facing": "left"},
        ),
        ('on', 'on'): (
            [[5, 3], [5, 5], [5, 7]],
            {"position": [5, 6], "facing": "inner"},
            {"position": [5, 4], "facing": "inner"},
        ),
    }

    relation_pair = (relation_1, relation_2)
    if relation_pair not in layouts:
        raise ValueError(f"Unknown spatial relation: {relation_1}, {relation_2}")
    positions, camera_a, camera_b = layouts[relation_pair]

    # Index explicitly so that a record with fewer than three objects fails.
    placed_objects = [
        gen_object_coordinate_dict(objects[slot], position)
        for slot, position in enumerate(positions)
    ]

    # "inverse" items show the two cameras in the opposite order.
    if 'inverse' in item_type:
        camera_a, camera_b = camera_b, camera_a
    view_coordinates = [
        {"name": "View 1", **camera_a},
        {"name": "View 2", **camera_b},
    ]

    # Objects with an empty name come back as None; drop them.
    placed_objects = [entry for entry in placed_objects if entry is not None]
    cogmap = {
        "objects": placed_objects,
        "views": view_coordinates,
    }
    return format_cogmap_json(cogmap), objects, []

def generate_cogmap_rotation(item) -> Tuple[str, list, list]:
    """
    Generate the cogmap for rotation setting.

    The camera stays at [5, 5] in every view and only its facing changes; one
    object is placed in the direction each view looks at.

    Args:
        item: Dataset record. Keys read: "type", "meta_info" (used as the list
            of object names) and "question". Supported types are
            'two_view_clockwise', 'two_view_counterclockwise',
            'two_view_opposite', 'three_view' and 'four_view'.

    Returns:
        A 3-tuple: the cogmap as a formatted JSON string, the object-name list
        exactly as found in ``meta_info``, and an empty list.

    Raises:
        AssertionError: If the number of objects does not match the type, or a
            type containing "three"/"four" is not exactly 'three_view' /
            'four_view'.
        ValueError: If the type is not recognised. This includes types that
            contain "two" but are none of the three two-view variants, which
            previously produced an empty cogmap without any error.
    """
    rotation_type = item.get("type", "")
    objects = item.get("meta_info", [])
    question = item.get("question", "")

    view_base_name = "Image" if "image" in question.lower() else "View"

    # type -> (object positions, camera facing per view)
    layouts = {
        'two_view_clockwise': ([[5, 3], [7, 5]], ["up", "right"]),  # example, front to right
        'two_view_counterclockwise': ([[5, 3], [3, 5]], ["up", "left"]),  # example, front to left
        'two_view_opposite': ([[5, 3], [5, 7]], ["up", "down"]),  # 180 degree rotation
        'three_view': ([[3, 5], [5, 3], [7, 5]], ["left", "up", "right"]),
        'four_view': ([[3, 5], [5, 3], [7, 5], [5, 7]], ["left", "up", "right", "down"]),
    }

    if "two" in rotation_type:
        assert len(objects) == 2, f"Two objects are expected for two view rotation, but got {len(objects)}"
        layout_key = rotation_type
    elif "three" in rotation_type:
        assert len(objects) == 3, f"Three objects are expected for three view rotation, but got {len(objects)}"
        assert rotation_type == 'three_view', f"Unknown type: {rotation_type}"
        layout_key = 'three_view'
    elif "four" in rotation_type:
        assert len(objects) == 4, f"Four objects are expected for four view rotation, but got {len(objects)}"
        assert rotation_type == 'four_view', f"Unknown type: {rotation_type}"
        layout_key = 'four_view'
    else:
        raise ValueError(f"Unknown type: {rotation_type}")

    if layout_key not in layouts:
        # Only reachable for an unrecognised two-view variant.
        raise ValueError(f"Unknown type: {rotation_type}")
    positions, facings = layouts[layout_key]

    objects_coordinates = [
        gen_object_coordinate_dict(objects[slot], position)
        for slot, position in enumerate(positions)
    ]
    # Objects with an empty name come back as None; drop them so the output
    # never contains a null entry (same as the other cogmap generators).
    objects_coordinates = [entry for entry in objects_coordinates if entry is not None]

    view_coordinates = [
        {"name": f"{view_base_name} {order}", "position": [5, 5], "facing": facing}
        for order, facing in enumerate(facings, start=1)
    ]

    cogmap = {
        "objects": objects_coordinates,
        "views": view_coordinates,
    }

    return format_cogmap_json(cogmap), objects, []

## ------------------Investigate Functions------------------

def investigate_how_many_spatial_relations_in_translation(items):
    """
    Tally translation items by their spatial-relation entry.

    Returns a dict mapping every distinct ``meta_info[0]`` value to the number
    of items that carry it.
    """
    tally = {}
    for entry in items:
        relation_key = entry.get("meta_info", [])[0]
        tally[relation_key] = tally.get(relation_key, 0) + 1
    return tally

def investigate_number_of_objects_in_translation(items):
    """
    Tally translation items by how many objects they list.

    The objects of an item are everything in ``meta_info`` after the first
    element. Returns a dict mapping object count to number of items.
    """
    tally = {}
    for entry in items:
        how_many = len(entry.get("meta_info", [])[1:])
        tally[how_many] = tally.get(how_many, 0) + 1
    return tally

def investigate_rotation_type_types(items):
    """
    Count how often each ``type`` value occurs among the given items.

    Items without a ``type`` key are counted under the empty string.

    Returns:
        dict mapping each ``type`` value to its count, in first-seen order.
    """
    per_kind = {}
    for entry in items:
        kind = entry.get("type", "")
        per_kind[kind] = per_kind.get(kind, 0) + 1
    return per_kind

def investigate_rotation_objects_num_equal_to_image_num(items):
    """
    Count items whose ``meta_info`` length matches their ``images`` length.

    Missing ``meta_info`` / ``images`` keys are treated as empty lists.

    Returns:
        dict with exactly two keys, ``"equal"`` and ``"not_equal"``.
    """
    outcome = {"equal": 0, "not_equal": 0}
    for entry in items:
        lengths_match = len(entry.get("meta_info", [])) == len(entry.get("images", []))
        outcome["equal" if lengths_match else "not_equal"] += 1
    return outcome

def investigate_around_obj_number(items):
    """
    Build a histogram of the object count stored in ``meta_info[1][0]``.

    For every item the stored count is checked against the length of
    ``meta_info[1][1]``.

    Returns:
        dict mapping the stored object count to the number of items.

    Raises:
        AssertionError: if the stored count differs from the list length.
        IndexError: if ``meta_info`` is missing or too short.
    """
    histogram = {}
    for entry in items:
        group = entry.get("meta_info", [])[1]
        declared = group[0]
        members = group[1]
        assert declared == len(members), f"Object number {declared} is not equal to objects length {len(members)}"
        histogram[declared] = histogram.get(declared, 0) + 1
    return histogram

def investigate_around_obj_orientation(items):
    """
    Count every value found in ``meta_info[1][2]`` across all items.

    Returns:
        dict mapping each value to its total number of occurrences, in
        first-seen order.

    Raises:
        IndexError: if ``meta_info`` is missing or too short.
    """
    totals = {}
    for entry in items:
        for facing in entry.get("meta_info", [])[1][2]:
            totals[facing] = totals.get(facing, 0) + 1
    return totals

def investigate_among_obj_number(items):
    """
    Build a histogram of ``len(meta_info[0])`` over the given items.

    Returns:
        dict mapping that length to the number of items having it.

    Raises:
        IndexError: if an item has no ``meta_info`` (or an empty one).
    """
    histogram = {}
    for entry in items:
        size = len(entry.get("meta_info", [])[0])
        histogram[size] = histogram.get(size, 0) + 1
    return histogram
        
def investigate_among_image_names(items):
    """
    Count image-name prefixes across all items.

    The prefix of an image is the part of its basename before the first
    underscore (the whole basename when there is no underscore).

    Returns:
        dict mapping each prefix to its number of occurrences, in
        first-seen order.
    """
    prefix_counts = {}
    for entry in items:
        for image_path in entry.get("images", []):
            prefix = os.path.basename(image_path).split("_")[0]
            prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1
    return prefix_counts
    
def investigate_among_question_img_num(items):
    """
    Build a histogram of the number of images attached to each item.

    Items without an ``images`` key count as having zero images.

    Returns:
        dict mapping an image count to the number of items with that count.
    """
    histogram = {}
    for entry in items:
        n_images = len(entry.get("images", []))
        histogram[n_images] = histogram.get(n_images, 0) + 1
    return histogram

## ---------------Investigation over Cog Map Formats------------------

def json_merge_obj_views_str(new_cogmap):
    """
    Render a ``name -> entry`` mapping as a brace-wrapped multi-line string.

    Each entry is written on its own line as ``  "<name>": <json of entry>``.
    Lines are NOT separated by commas and names are not escaped, so the
    result is a compact prompt format rather than strict JSON.

    Args:
        new_cogmap: dict mapping an object/view name to a JSON-serialisable
            value; entries are emitted in the dict's iteration order.

    Returns:
        The rendered string (``'{\\n}'`` for an empty mapping).
    """
    cogmap_str = '{\n'
    for name, entry in new_cogmap.items():
        cogmap_str += f'  "{name}": {json.dumps(entry)}\n'
    cogmap_str += '}'
    return cogmap_str

def json_merge_obj_views_cogmap(cogmap_str):
    """
    Merge the objects and views of a cogmap into one name-keyed map.

    The cogmap JSON is parsed, every entry of ``objects`` and ``views`` is
    reduced to ``position`` (plus ``facing`` when present) and keyed by its
    ``name``. A later entry with the same name replaces the earlier one.
    The merged entries are then ordered row by row: by y (``position[1]``)
    first and by x (``position[0]``) within a row. Ties keep their
    objects-before-views insertion order because the sort is stable.

    Args:
        cogmap_str: JSON text with optional ``objects`` and ``views`` lists.

    Returns:
        The merged map rendered by ``json_merge_obj_views_str``.

    Raises:
        json.JSONDecodeError: if ``cogmap_str`` is not valid JSON.
        KeyError: if an entry lacks ``name`` or ``position``.
    """
    cogmap = json.loads(cogmap_str)
    new_cogmap = {}
    for group in (cogmap.get("objects", []), cogmap.get("views", [])):
        for entry in group:
            merged_entry = {"position": entry["position"]}
            if "facing" in entry:
                merged_entry["facing"] = entry["facing"]
            new_cogmap[entry["name"]] = merged_entry

    # Sort once on the (y, x) pair. The previous nested loops inserted every
    # key during the first inner pass, which ordered the map by x only.
    ordered_names = sorted(
        new_cogmap,
        key=lambda name: (new_cogmap[name]["position"][1], new_cogmap[name]["position"][0]),
    )
    sorted_cogmap = {name: new_cogmap[name] for name in ordered_names}
    return json_merge_obj_views_str(sorted_cogmap)

def json_get_view_facing_objects(objects, view):
    """
    Record on ``view`` the names of all objects lying on its facing side.

    ``view`` is modified in place: its ``facing_objects`` key is reset to a
    list of object names, in the order of ``objects``, selected as follows:

    * ``up``    - objects with y smaller than the view's y
    * ``down``  - objects with y larger than the view's y
    * ``left``  - objects with x smaller than the view's x
    * ``right`` - objects with x larger than the view's x
    * ``inner`` - objects sharing the view's x or the view's y

    Args:
        objects: iterable of dicts with ``name`` and ``position`` ([x, y]).
        view: dict with ``position`` ([x, y]) and ``facing``.

    Returns:
        The same ``view`` dict that was passed in.

    Raises:
        ValueError: for any other facing value, raised while examining the
            first object (so never when ``objects`` is empty).
    """
    origin = view["position"]
    direction = view["facing"]
    seen = view["facing_objects"] = []

    # One membership test per supported direction; p is an object position.
    rules = (
        ("up", lambda p: p[1] < origin[1]),
        ("down", lambda p: p[1] > origin[1]),
        ("left", lambda p: p[0] < origin[0]),
        ("right", lambda p: p[0] > origin[0]),
        ("inner", lambda p: p[0] == origin[0] or p[1] == origin[1]),
    )
    in_front = next((check for label, check in rules if label == direction), None)

    for candidate in objects:
        spot = candidate["position"]
        if in_front is None:
            raise ValueError(f"Unknown facing: {direction}")
        if in_front(spot):
            seen.append(candidate["name"])
    return view
                

def json_add_facing_what_objs_for_views_in_cogmap(cogmap_str):
    """
    Annotate every view of a cogmap with the objects on its facing side.

    The cogmap JSON is parsed, each entry of ``views`` is passed together
    with the cogmap's ``objects`` to ``json_get_view_facing_objects`` (which
    fills in ``facing_objects`` in place), and the updated cogmap is
    rendered with ``format_cogmap_json``.

    Raises:
        KeyError: if there is at least one view but no ``objects`` key.
    """
    parsed = json.loads(cogmap_str)
    for camera in parsed.get("views", []):
        # The helper mutates the view dict, so its return value is not needed.
        json_get_view_facing_objects(parsed["objects"], camera)
    return format_cogmap_json(parsed)

def json_get_view_facing_main_obj(objects, view):
    """
    Record on ``view`` the main object(s) directly in front of it.

    ``view`` is modified in place: its ``facing_objects`` key is reset to a
    list of object names chosen as follows.

    * ``up`` / ``down`` / ``left`` / ``right``: objects on the cell exactly
      one step ahead of the view in that direction. Only when that cell
      holds no object are the objects exactly two steps ahead used instead.
      The choice does not depend on the order of ``objects``.
    * ``inner``: every object sharing the view's x or the view's y.
    * any other facing (including ``outer``): the list stays empty.

    Args:
        objects: iterable of dicts with ``name`` and ``position`` ([x, y]).
        view: dict with ``position`` ([x, y]) and ``facing``.

    Returns:
        The same ``view`` dict that was passed in.
    """
    view_x, view_y = view["position"]
    view_facing = view["facing"]

    # Unit step (dx, dy) on the grid for each in-plane direction; y grows downwards.
    steps = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}

    adjacent = []   # names found one step ahead
    two_away = []   # names found two steps ahead (fallback only)
    same_line = []  # names sharing a row/column with an "inner" view

    for obj in objects:
        obj_x, obj_y = obj["position"]
        if view_facing == "inner":
            if obj_x == view_x or obj_y == view_y:
                same_line.append(obj["name"])
        elif view_facing in steps:
            dx, dy = steps[view_facing]
            if obj_x == view_x + dx and obj_y == view_y + dy:
                adjacent.append(obj["name"])
            elif obj_x == view_x + 2 * dx and obj_y == view_y + 2 * dy:
                two_away.append(obj["name"])

    if view_facing == "inner":
        view["facing_objects"] = same_line
    else:
        # Decide the fallback only after every object has been seen, so a
        # two-away object listed before the adjacent one is not kept.
        view["facing_objects"] = adjacent or two_away
    return view
    

def json_add_facing_what_main_obj_for_views_in_cogmap(cogmap_str):
    """
    Annotate every view of a cogmap with the main object(s) it faces.

    The cogmap JSON is parsed, each entry of ``views`` is passed together
    with the cogmap's ``objects`` to ``json_get_view_facing_main_obj`` (which
    fills in ``facing_objects`` in place), and the updated cogmap is
    rendered with ``format_cogmap_json``.

    Raises:
        KeyError: if there is at least one view but no ``objects`` key.
    """
    parsed = json.loads(cogmap_str)
    for camera in parsed.get("views", []):
        # The helper mutates the view dict, so its return value is not needed.
        json_get_view_facing_main_obj(parsed["objects"], camera)
    return format_cogmap_json(parsed)

def json_facing_obj_merge(cogmap_str):
    """
    Merge objects and facing-annotated views into one name-keyed map.

    Objects are reduced to ``position`` (plus ``facing`` when present).
    Views keep ``position``, ``facing`` and ``facing_objects``. A view with
    the same name as an object replaces it. The merged entries are ordered
    row by row: by y (``position[1]``) first and by x (``position[0]``)
    within a row. Ties keep their objects-before-views insertion order
    because the sort is stable.

    Args:
        cogmap_str: JSON text with optional ``objects`` and ``views`` lists;
            every view must already carry ``facing_objects``.

    Returns:
        The merged map rendered by ``json_merge_obj_views_str``.

    Raises:
        json.JSONDecodeError: if ``cogmap_str`` is not valid JSON.
        KeyError: if a required key is missing from an object or view.
    """
    cogmap = json.loads(cogmap_str)
    objects = cogmap.get("objects", [])
    views = cogmap.get("views", [])
    new_cogmap = {}
    for obj in objects:
        merged_entry = {"position": obj["position"]}
        if "facing" in obj:
            merged_entry["facing"] = obj["facing"]
        new_cogmap[obj["name"]] = merged_entry
    for view in views:
        new_cogmap[view["name"]] = {
            "position": view["position"],
            "facing": view["facing"],
            "facing_objects": view["facing_objects"],
        }

    # Sort once on the (y, x) pair. The previous nested loops inserted every
    # key during the first inner pass, which ordered the map by x only.
    ordered_names = sorted(
        new_cogmap,
        key=lambda name: (new_cogmap[name]["position"][1], new_cogmap[name]["position"][0]),
    )
    sorted_cogmap = {name: new_cogmap[name] for name in ordered_names}
    return json_merge_obj_views_str(sorted_cogmap)

## ------------------Main Entrance------------------

def _setting_for_id(question_id):
    """Return the first setting name contained in ``question_id``, or None."""
    for candidate in ("around", "among", "translation", "rotation"):
        if candidate in question_id:
            return candidate
    return None


def generate_cogmap(input_file_path, output_file_path):
    """
    Main entrance for generating cognitive maps for all four settings.

    Reads a JSONL file, assigns every record to one of the settings
    ``around`` / ``among`` / ``translation`` / ``rotation`` according to the
    first of those words found in its ``id``, and calls the module-level
    function ``generate_cogmap_<setting>`` on each record. That function is
    expected to return ``(cogmap_str, main_objects, oriented_objects)``.
    The record then receives three new keys: ``cogmap``, ``cogmap_input``
    and ``cogmap_output``, the last two built from the module-level prompt
    templates. Finally every record that matched a setting is written to
    ``output_file_path`` as JSONL, in the order it was read.

    Args:
        input_file_path: path of the JSONL file to read.
        output_file_path: path of the JSONL file to write; its parent
            directory is created when missing.

    Returns:
        None. If the input file is missing or contains invalid JSON, an
        error is printed and nothing is written.

    Raises:
        AssertionError: if ``generate_cogmap_<setting>`` is not a callable
            module-level name (checked for all four settings, even those
            without records).
    """
    target_dir = os.path.dirname(output_file_path)
    if target_dir and not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)

    # Records grouped by setting; the key order fixes the processing order.
    grouped = {name: [] for name in ("around", "among", "translation", "rotation")}

    # Every parsed record, in file order, so the output keeps that order.
    ordered_items = []

    try:
        with open(input_file_path, 'r', encoding='utf-8') as f_in:
            for raw_line in f_in:
                if not raw_line.strip():
                    continue
                record = json.loads(raw_line)
                ordered_items.append(record)

                question_id = record.get("id", "")
                bucket = _setting_for_id(question_id)
                if bucket is None:
                    print(f"Warning: Question with id {question_id} does not match any known setting")
                else:
                    grouped[bucket].append(record)
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_file_path}")
        return
    except json.JSONDecodeError as e:
        print(f"Error: Could not decode JSON from input file {input_file_path}. Details: {e}")
        return

    # Report how many records were loaded per setting.
    total_questions = sum(len(members) for members in grouped.values())
    print(f"Total number of questions: {total_questions}")
    print(f"Found {len(grouped['around'])} around questions, {len(grouped['among'])} among questions, {len(grouped['translation'])} translation questions, {len(grouped['rotation'])} rotation questions")

    # Records are processed setting by setting. They are mutated in place,
    # so the file order kept in ordered_items is unaffected.
    for setting, members in grouped.items():
        func_name = f"generate_cogmap_{setting}"
        cogmap_func = globals().get(func_name, None)
        assert callable(cogmap_func), f"Function {func_name} is not callable"

        # Only the translation setting omits the bird's-view wording.
        if setting == "translation":
            birdview_input_str = ""
            birdview_output_str = ""
        else:
            birdview_input_str = "\n- The map is shown in the bird's view"
            birdview_output_str = " in the bird's view"

        for item in members:
            cogmap_str, main_objects, oriented_objects = cogmap_func(item)

            if len(oriented_objects) > 0:
                orientation_info = f"4. For objects [{' '.join(oriented_objects)}], determine their facing direction as up, right, down, or left. For other objects, omit the facing direction."
            else:
                orientation_info = ""

            item["cogmap"] = cogmap_str
            item["cogmap_input"] = COG_MAP_DESCRIPTION_FOR_INPUT.replace("{birdview}", birdview_input_str)

            # Fill the output template: bird's-view wording first, then the
            # orientation instruction, then the main object list.
            output_text = COG_MAP_DESCRIPTION_FOR_OUTPUT.replace("{birdview}", birdview_output_str)
            output_text = output_text.replace("<orientation_info>", orientation_info)
            item["cogmap_output"] = output_text.replace("{obj}", ", ".join(main_objects))

    # Write the records back in their original order; records whose id
    # matched no setting are dropped.
    with open(output_file_path, 'w', encoding='utf-8') as f_out:
        for record in ordered_items:
            if _setting_for_id(record.get("id", "")) is not None:
                f_out.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    # Script entry point: build one cogmap file per input that has no
    # output yet under output_root_dir.
    input_root_dir = "./data/raw"
    output_root_dir = "./data/cog_map/prompt_full_map"

    input_file_paths = [
        "/path/to/crossviewQA_tinybench.jsonl",
        "/path/to/crossviewQA_tinybench.jsonl",
    ]
    for input_file_path in input_file_paths:
        print(f"Processing {input_file_path}")
        # The output keeps the input's file name.
        output_file_path = os.path.join(output_root_dir, os.path.basename(input_file_path))
        if not os.path.exists(output_file_path):
            generate_cogmap(input_file_path, output_file_path)
        else:
            # An existing output is never overwritten.
            print(f"Skipping {input_file_path} because it already exists")